import torch
import numpy as np
from sklearn.model_selection import train_test_split
from typing import NamedTuple, Optional

class PVTrainDataSet(NamedTuple):
    """Training split of a "PV" dataset (presumably proxy-variable data).

    ``extract_data_from_npz`` fills the first four fields from the archive
    arrays ``train_a``, ``train_z``, ``train_w`` and ``train_y``. Despite the
    ``torch.Tensor`` annotations, those fields hold numpy arrays when that
    loader is called with ``totensor=False``. Field order matters because
    instances are also rebuilt positionally.
    """

    treatment: torch.Tensor  # loaded from "train_a"
    treatment_proxy: torch.Tensor  # loaded from "train_z"
    outcome_proxy: torch.Tensor  # loaded from "train_w"
    outcome: torch.Tensor  # loaded from "train_y"
    # Optional extra covariates (by name, presumably back-door adjustment
    # variables). The npz loader always sets None, so it defaults to None and
    # callers may omit it.
    backdoor: Optional[torch.Tensor] = None

class PVTestDataSet(NamedTuple):
    """Test split of a "PV" dataset (presumably proxy-variable data).

    ``extract_data_from_npz`` fills the first four fields from the archive
    arrays ``test_a``, ``test_z``, ``test_w`` and ``test_y``. Despite the
    ``torch.Tensor`` annotations, those fields hold numpy arrays when that
    loader is called with ``totensor=False``.
    """

    treatment: torch.Tensor  # loaded from "test_a"
    treatment_proxy: torch.Tensor  # loaded from "test_z"
    outcome_proxy: torch.Tensor  # loaded from "test_w"
    outcome: torch.Tensor  # loaded from "test_y"
    # Optional extra covariates (by name, presumably back-door adjustment
    # variables). The npz loader always sets None, so it defaults to None and
    # callers may omit it.
    backdoor: Optional[torch.Tensor] = None


def split_train_data(train_data: PVTrainDataSet, split_ratio=0.5, random_state=42):
    """Randomly split a training set into two disjoint subsets.

    Args:
        train_data: Dataset whose non-None fields can be indexed along their
            first axis with an integer index array (torch tensors or numpy
            arrays). ``None`` fields stay ``None`` in both outputs.
        split_ratio: Passed to sklearn's ``train_test_split`` as
            ``train_size``; it is the fraction (or count) of rows placed in
            the first subset. Any negative value disables splitting.
        random_state: Seed forwarded to ``train_test_split``. The default of 42
            is the previously hard-coded value, so existing splits are
            reproduced exactly.

    Returns:
        ``(first, second)`` as two ``PVTrainDataSet`` instances. When
        ``split_ratio < 0`` both elements are ``train_data`` itself (the same
        object, not a copy).

    Raises:
        ValueError: From ``train_test_split`` when ``split_ratio`` is not a
            valid ``train_size`` (e.g. ``0.0`` or ``>= 1.0`` as a float).
    """
    if split_ratio < 0.0:
        # Sentinel: negative ratio means "use the full data for both stages".
        return train_data, train_data

    # Row count taken from the first field (treatment for a PVTrainDataSet).
    n_data = train_data[0].shape[0]
    idx_train_1st, idx_train_2nd = train_test_split(
        np.arange(n_data), train_size=split_ratio, random_state=random_state
    )

    def get_data(data, idx):
        # Optional fields (e.g. backdoor) may be None; propagate them as-is.
        return data[idx] if data is not None else None

    train_1st_data = PVTrainDataSet(*[get_data(data, idx_train_1st) for data in train_data])
    train_2nd_data = PVTrainDataSet(*[get_data(data, idx_train_2nd) for data in train_data])

    return train_1st_data, train_2nd_data


def extract_data_from_npz(npz_file_path, totensor=False):
    """Load the train/test datasets from an ``.npz`` archive.

    The archive must contain ``train_a``, ``train_z``, ``train_w``, ``train_y``
    and the matching ``test_*`` arrays. They map to treatment, treatment
    proxy, outcome proxy and outcome respectively. ``backdoor`` is always
    ``None``.

    Args:
        npz_file_path: Anything ``np.load`` accepts for an ``.npz`` archive.
        totensor: If True, wrap every array with the legacy ``torch.Tensor``
            constructor, which produces the default floating dtype (float32
            unless changed). Otherwise the numpy arrays are returned as stored.

    Returns:
        ``(PVTrainDataSet, PVTestDataSet)``.

    Raises:
        KeyError: If one of the expected arrays is missing from the archive.
    """
    convert = torch.Tensor if totensor else (lambda array: array)

    def load_split(archive, prefix):
        # Read the four arrays of one split ("train" or "test") in a/z/w/y order.
        return dict(
            treatment=convert(archive[prefix + "_a"]),
            treatment_proxy=convert(archive[prefix + "_z"]),
            outcome_proxy=convert(archive[prefix + "_w"]),
            outcome=convert(archive[prefix + "_y"]),
            backdoor=None,
        )

    # NpzFile keeps the underlying file open until closed. Indexing it reads
    # each array fully into memory, so the results stay valid after the
    # `with` block releases the handle.
    with np.load(npz_file_path) as data:
        pv_train_data = PVTrainDataSet(**load_split(data, "train"))
        pv_test_data = PVTestDataSet(**load_split(data, "test"))
    return pv_train_data, pv_test_data


def extract_effect_from_npz(npz_file_path):
    """Load the ``do_A`` and ``gt_EY_do_A`` arrays from an ``.npz`` archive.

    By their names these are presumably the intervention values ``a`` and the
    ground-truth ``E[Y | do(A=a)]``. Confirm against the data generator.

    Args:
        npz_file_path: Anything ``np.load`` accepts for an ``.npz`` archive.

    Returns:
        ``(do_A, gt_EY_do_A)`` as numpy arrays.

    Raises:
        KeyError: If either array is missing from the archive.
    """
    # Close the archive's file handle. Both arrays are fully read into memory
    # by indexing before the handle is released.
    with np.load(npz_file_path) as data:
        return data['do_A'], data['gt_EY_do_A']